import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os
# from itertools import chain
# from docutils.nodes import legend

# Seed torch's global RNG so every random draw made later in this file
# (the reference vectors and the dataset noise) is reproducible.
torch.manual_seed(43)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def function(x, y):
    """Return the paraboloid height ``0.5 * x**2 + 0.5 * y**2``.

    Works element-wise on anything that supports ``**``, ``*`` and ``+``
    (plain numbers as well as tensors).
    """
    half_x_squared = 0.5 * x ** 2
    half_y_squared = 0.5 * y ** 2
    return half_x_squared + half_y_squared

hidden_dim = 128


def _random_unit_row(width):
    """Draw a (1, width) standard-normal row, move it to ``device`` and scale it by its L2 norm."""
    row = torch.normal(0, 1, size=(1, width)).to(device)
    # The 1e-4 added to the norm keeps the division well defined.
    return row / (row.norm(2, 1, keepdim=True) + 1e-4)


# Fixed random reference directions, one per layer index (0, 1, 2); each one
# matches that layer's output width. They are drawn in this order so the
# seeded RNG stream is consumed exactly as before.
p_vector0 = _random_unit_row(2 * hidden_dim)
p_vector1 = _random_unit_row(hidden_dim)
p_vector2 = _random_unit_row(hidden_dim // 2)


class Net(nn.Module):
    """Three stacked ``Layer`` modules that are trained greedily, one layer at a time.

    Every layer is scored by a "goodness": the cosine similarity between its
    output and a fixed random reference vector (``p_vector0`` .. ``p_vector2``).
    Inputs are rows of ``(x, y, z, label)``.
    """

    def __init__(self, device):
        """Build the layers and move each of them to ``device``."""
        super(Net, self).__init__()
        self.theta = 3.0

        # Flat list of (in_features, out_features) pairs, consumed two at a time.
        dims = [4, 2 * hidden_dim, 2 * hidden_dim, hidden_dim, hidden_dim, hidden_dim // 2]
        # NOTE: this is a plain Python list, so the layers are NOT registered as
        # submodules: parameters(), state_dict(), .to() and eval() on Net do not
        # reach them. Registering them (e.g. nn.ModuleList) is only safe if
        # Layer.train can also handle the train(mode) call that
        # nn.Module.eval()/train() make on every submodule.
        self.layers = [
            Layer(dims[d], dims[d + 1], self.theta).to(device)
            for d in range(0, len(dims), 2)
        ]

    def predict(self, x, y):
        """Return the candidate z values classified as label 1 at the point (x, y).

        300 evenly spaced z candidates in [-10, 30] are scored twice, once with
        label 0.0 and once with label 1.0 appended. A candidate is kept when its
        summed goodness under label 1 is lower than under label 0.

        Returns:
            Tensor of shape (n_kept, 1); it is empty when no candidate is kept.
        """
        cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        n_z = 300  # number of candidate z values scored at this (x, y)
        z = torch.linspace(-10, 30, n_z).view(n_z, 1).to(device)
        xyz = torch.cat((x.repeat(n_z, 1), y.repeat(n_z, 1), z), 1)
        references = (p_vector0, p_vector1, p_vector2)

        goodness_per_label = []
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for label in (0.0, 1.0):
                label_column = torch.ones_like(xyz[:, 0].unsqueeze(1)) * label
                h = torch.cat((xyz, label_column), 1)

                goodness = []
                for k, (layer, reference) in enumerate(zip(self.layers, references)):
                    h = layer(h, k)
                    goodness.append(cos(h, reference.repeat(h.shape[0], 1)))
                goodness_per_label.append(sum(goodness).unsqueeze(1))

        goodness_per_label = torch.cat(goodness_per_label, 1)  # shape (n_z, 2)
        mask = goodness_per_label[:, 1] < goodness_per_label[:, 0]
        return z[mask]

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the last activation."""
        for k, layer in enumerate(self.layers):
            # Layer.forward takes the layer index as its second argument.
            x = layer(x, k)
        return x

    def train_model(self, dataloader, n):
        """Train every layer in turn on each batch of ``dataloader``.

        The first half of a batch is treated as positive samples and the second
        half as negative samples; the target ``y`` is appended to the features
        as an extra column. ``n`` is accepted for interface compatibility and
        is not used.
        """
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            half = x.shape[0] // 2
            x_pos = torch.cat((x[:half], y[:half].unsqueeze(1)), 1)  # positive data
            x_neg = torch.cat((x[half:], y[half:].unsqueeze(1)), 1)  # negative data

            h_pos, h_neg = x_pos, x_neg
            # The layer index restarts at 0 for every batch so that each layer
            # is always paired with its own reference vector.
            for k, layer in enumerate(self.layers):
                h_pos, h_neg, loss = layer.train(h_pos, h_neg, k)
                # The value reported by Layer.train is the mean goodness gap.
                print("Layer", k + 1, "Loss", loss)


class Layer(nn.Module):
    """A ``Linear -> GELU`` block with its own Adam optimiser, trained locally.

    The layer's "goodness" is the cosine similarity between its output and a
    fixed reference vector chosen by the layer index ``k`` (``p_vector0``,
    ``p_vector1`` or ``p_vector2``).
    """

    def __init__(self, in_features, out_features, theta, layer_epochs=None):
        """Create the layer.

        Args:
            in_features: width of the input rows.
            out_features: width of the output rows.
            theta: scale applied to the goodness gap inside the loss.
            layer_epochs: number of optimisation steps per ``train`` call.
                When omitted, the module-level ``n_epochs`` global is used
                (it is defined by the ``__main__`` script section).
        """
        super(Layer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.theta = theta
        self.layer = nn.Linear(in_features, out_features)
        self.main = nn.Sequential(self.layer, nn.GELU())
        self.layer_epochs = n_epochs if layer_epochs is None else layer_epochs
        self.opt = torch.optim.Adam(self.layer.parameters(), lr=0.001, eps=1e-10)

    def forward(self, x, k):
        """Apply ``Linear -> GELU`` to ``x`` reshaped to (-1, in_features); ``k`` is unused."""
        x = x.view(-1, self.in_features)
        return self.main(x)

    @staticmethod
    def _reference(k):
        """Return the reference vector for layer index ``k`` (0, 1 or 2)."""
        if k not in (0, 1, 2):
            raise ValueError(f"no reference vector for layer index {k!r}; expected 0, 1 or 2")
        return (p_vector0, p_vector1, p_vector2)[k]

    def goodness(self, x_pos, x_neg, k):
        """Return per-row cosine similarities ``(g_pos, g_neg)`` to the reference vector.

        Raises:
            ValueError: if ``k`` is not 0, 1 or 2.
        """
        reference = self._reference(k)
        h_pos = self.forward(x_pos, k)
        h_neg = self.forward(x_neg, k)

        cos = nn.CosineSimilarity()
        g_pos = cos(h_pos, reference.repeat(h_pos.shape[0], 1))
        g_neg = cos(h_neg, reference.repeat(h_neg.shape[0], 1))
        return g_pos, g_neg

    def train(self, x_pos=True, x_neg=None, k=None):
        """Optimise this layer, or switch its training mode.

        ``train(x_pos, x_neg, k)`` runs ``layer_epochs`` Adam steps that
        minimise ``exp(theta * (g_pos - g_neg)).mean()``, which pushes the
        goodness of positive rows below that of negative rows. It returns
        ``(h_pos, h_neg, gap)``: the detached layer outputs for both inputs
        and the mean goodness gap ``g_pos - g_neg`` from the last step (also
        detached).

        ``train(mode)`` with a single positional bool behaves like
        ``nn.Module.train(mode)``. This method shadows that name, and
        ``nn.Module.eval()`` relies on it, so the mode-switching call must keep
        working.
        """
        if x_neg is None and k is None:
            return super(Layer, self).train(x_pos)

        self.running_loss = 0.0
        delta = None
        for _ in range(self.layer_epochs):
            g_pos, g_neg = self.goodness(x_pos, x_neg, k)
            delta = g_pos - g_neg
            loss = torch.exp(self.theta * delta).mean()

            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            self.running_loss += loss.item()

        with torch.no_grad():
            if delta is None:
                # No optimisation step ran; report the gap of the untouched layer.
                g_pos, g_neg = self.goodness(x_pos, x_neg, k)
                delta = g_pos - g_neg
            gap = delta.detach().mean()
            return self.forward(x_pos, k), self.forward(x_neg, k), gap


def weights_init(m):
    """Initialise an ``nn.Linear`` module in place; leave every other module alone.

    Weights get Xavier-normal values and the bias, when the layer has one, is
    set to zero. Intended for use with ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight)
    # nn.Linear(..., bias=False) stores bias as None, which cannot be zeroed.
    if m.bias is not None:
        nn.init.zeros_(m.bias)


def create_3d_dataset(x_range, y_range, N, tol, n_intol, n_outtol):
    """Sample points near to and away from the surface ``z = function(x, y)``.

    An ``N x N`` grid is laid over ``x_range`` x ``y_range``.

    * In-tolerance points: ``n_intol`` copies of the grid whose z is the
      surface height plus Gaussian noise scaled by ``tol``.
    * Out-of-tolerance points: for every grid node, ``n_outtol`` z values
      spaced linearly from ``z + tol`` up to the largest noisy in-tolerance z
      and from the smallest noisy in-tolerance z up to ``z - tol``.

    Args:
        x_range, y_range: ``(low, high)`` bounds of the grid.
        N: number of grid points per axis.
        tol: noise scale and half-width of the excluded band around the surface.
        n_intol: in-tolerance samples per grid node.
        n_outtol: out-of-tolerance samples per grid node. An odd count puts
            the extra sample above the surface.

    Returns:
        ``(in_tol_points, out_tol_points)`` with shapes
        ``(3, n_intol * N * N)`` and ``(3, n_outtol * N * N)``; the rows are
        x, y and z.
    """
    xs = torch.linspace(x_range[0], x_range[1], N)
    ys = torch.linspace(y_range[0], y_range[1], N)
    X, Y = torch.meshgrid(xs, ys, indexing='ij')
    Z = function(X, Y)

    # Noisy copies of the surface.
    z_tiled = Z.repeat(n_intol, 1)
    z_in_tol = z_tiled + torch.randn(z_tiled.shape) * tol
    in_tol_points = torch.cat(
        (X.repeat(n_intol, 1), Y.repeat(n_intol, 1), z_in_tol)
    ).view(3, -1)

    ceiling = z_in_tol.max()
    floor = z_in_tol.min()

    # Split so the two parts always add up to n_outtol, even when it is odd;
    # for even counts this is the same half/half split as n_outtol // 2.
    n_below = n_outtol // 2
    n_above = n_outtol - n_below
    columns = [
        torch.cat((
            torch.linspace(z + tol, ceiling, n_above),
            torch.linspace(floor, z - tol, n_below),
        ))
        for z in Z.flatten()
    ]
    z_out_tol = torch.stack(columns, dim=1)  # shape (n_outtol, N * N)

    x_out = X.repeat(n_outtol, 1)
    y_out = Y.repeat(n_outtol, 1)
    out_tol_points = torch.cat(
        (x_out, y_out, z_out_tol.reshape(x_out.shape))
    ).view(3, -1)
    return in_tol_points, out_tol_points

if __name__ == '__main__':
    # ---- experiment configuration ----
    N = 25  # grid points per axis
    # Must remain a module-level global: Layer.__init__ reads ``n_epochs``.
    n_epochs = 5000
    x_range = (-5, 5)
    y_range = (-5, 5)
    tol = 2
    n_intol = 30
    n_outtol = 50

    def append_label(tensor, value):
        """Return ``tensor`` with one extra row filled with ``value``."""
        label_row = torch.full((1, tensor.shape[1]), value)
        return torch.cat((tensor, label_row), dim=0)

    # Create dataset
    in_tol_points, out_tol_points = create_3d_dataset(x_range, y_range, N, tol, n_intol, n_outtol)

    # Positive data carries the correct label (1 = in tolerance, 0 = out of
    # tolerance); negative data carries the swapped label.
    positive_data = torch.cat((append_label(in_tol_points, 1), append_label(out_tol_points, 0)), dim=1)
    negative_data = torch.cat((append_label(in_tol_points, 0), append_label(out_tol_points, 1)), dim=1)
    dataset = torch.cat((positive_data, negative_data), 1)

    # Features are the (x, y, z) rows; the label row becomes the target.
    dataset = torch.utils.data.TensorDataset(dataset[:3, :].T, dataset[3, :])
    # One batch holding everything, unshuffled: Net.train_model treats the
    # first half of a batch as positive and the second half as negative.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=negative_data.shape[1] + positive_data.shape[1],
        shuffle=False,
    )

    model = Net(device)
    model.train_model(dataloader, N**2)
    torch.save(model, 'model.pth')
    # model = torch.load('model.pth',weights_only=False)
    model.eval()

    x = torch.linspace(-5, 5, 100).to(device)
    y = torch.linspace(-5, 5, 100).to(device)
    x, y = torch.meshgrid(x, y, indexing='ij')
    z_original = function(x, y)
    z_pred = torch.zeros_like(z_original)
    # Pure inference: skip autograd bookkeeping for the 100 x 100 predict calls.
    with torch.no_grad():
        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                # Mean of the accepted z candidates; NaN when none is accepted.
                z_pred[i, j] = model.predict(x[i, j], y[i, j]).mean()

    # Plotting the predicted surface
    x_np = x.cpu().numpy()
    y_np = y.cpu().numpy()
    plt.figure(figsize=(12, 8))
    ax = plt.axes(projection='3d')
    ax.plot_surface(x_np, y_np, z_pred.cpu().numpy(), color='yellow', label="Predicted Plane")
    ax.plot_surface(x_np, y_np, z_original.cpu().numpy(), color='green', label="Actual Plane")
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.legend()
    plt.show()

    
    



